【BQML応用記事】BigQuery MLで作った機械学習のモデルでオンライン予測を実施する
こんにちは、Mr.Moです。
このエントリは『クラスメソッド BigQuery Advent Calendar 2020』25本目のエントリです。本アドベントカレンダーもこのエントリで最後とのことで、これまで弊社のメンバーで書いてきた記事もご覧いただきながら最後までお楽しみいただければと思います。
- クラスメソッド BigQuery Advent Calendar 2020 の記事一覧 | Developers.IO
- クラスメソッド BigQuery Advent Calendar 2020 - Qiita
当エントリではBigQuery ML(以下、BQML)のさらなる実用的な使い方である、オンライン予測ができるまでを見ていこうと思います。
BQMLとは
BQMLの概要については下記の記事にまとめてますのでご覧いただければと思います。
オンライン予測ができるまでをやっていく
今回はkaggleのタイタニックのデータを使って生存者の予測をするというお題で進めていこうと思います。
データの内容は下記のようになっています。このデータからタイタニック号での生存を予測(survival
の0,1を予測)しようというものです。
変数 | 定義 | key |
---|---|---|
survival | 生存 | 1 = 生存, 0 = 生存しない |
pclass | チケットのクラス | 1 = 1st, 2 = 2nd, 3 = 3rd |
sex | 性別 | |
age | 年齢 | |
sibsp | タイタニック号に乗っている兄弟/配偶者の数 | |
parch | タイタニック号に乗っている親/子の数 | |
ticket | チケット番号 | |
fare | 旅客運賃 | |
cabin | キャビン番号 | |
embarked | 乗船した港 | C = Cherbourg, Q = Queenstown, S = Southampton |
先にトレーニングデータ(train.csv)をデータセットに追加しておきます。
モデルのトレーニング
先ほど追加したデータを使ってトレーニングを行います。BQMLならSQLで簡単にトレーニングも実施できますね。 使うモデルはXGBoostでいこうと思います。(ちなみに最初はAutoML Tablesを使う予定でしたが、オンライン予測はまだ対応していませんでした)
CREATE OR REPLACE MODEL Titanic.xgboost_model OPTIONS( MODEL_TYPE='boosted_tree_classifier', INPUT_LABEL_COLS=["Survived"] ) AS SELECT * EXCEPT(PassengerId, Name, Ticket, Fare, Cabin) FROM `Titanic.train`
モデルのエクスポート
モデルのトレーニングが完了したら、Cloud Storage バケットにモデルをエクスポートします。 Cloud Shell上で下記のコマンドを実行します。
$ gsutil mb gs://titanic-2020 $ bq extract --destination_format ML_XGBOOST_BOOSTER -m Titanic.xgboost_model gs://titanic-2020/xgboost_model
モデルのデプロイ
そしてモデルをデプロイしていきます。デプロイ以降はAI Platformサービスを使います。
モデルリソースを作成し(ここでregionの選択を求められますが[1] global
を選択しました)
$ MODEL_NAME="TITANIC_XGBOOST_MODEL" $ gcloud ai-platform models create $MODEL_NAME
モデル バージョンを作成して
$ MODEL_DIR="gs://titanic-2020/xgboost_model" $ VERSION_NAME="v1" $ gcloud beta ai-platform versions create $VERSION_NAME --model=$MODEL_NAME --origin=$MODEL_DIR --package-uris=${MODEL_DIR}/xgboost_predictor-0.1.tar.gz--prediction-class=predictor.Predictor --runtime-version=1.15 --machine-type="mls1-c1-m2"
エラーが出なければOKです。 さっそくコマンドで予測を実行してみましょう。
$ vi instances.json $ INPUT_DATA_FILE="instances.json" $ cat instances.json {"Pclass":2, "Sex":"male", "Age":33, "SibSp":1, "Parch":2, "Embarked":"Q"} $ gcloud ai-platform predict --model $MODEL_NAME --version $VERSION_NAME --json-instances $INPUT_DATA_FILE
無事、予測結果が返ってきましたね!
サービスアカウントの作成
もう少し色んなところで使えるようにしていきたいと思います。curlやプログラムでも作った予測モデルを使えるようにサービスアカウントを作成して実行権限を付与していきます。
そしてjson形式のkey fileを作成・ダウンロードします。(ai-work-275303-xxx.json
のファイル)
curlで予測実行
下準備が整いましたのでまずは軽くcurlから予測を実行してみます。
$ export GOOGLE_APPLICATION_CREDENTIALS="/home/takashi1_kawamoto/ai-work-275303-b0d2af79eeab.json" $ export YOUR_PROJECT_ID="ai-work-275303" $ export YOUR_MODEL_NAME="TITANIC_XGBOOST_MODEL" $ curl -H "Authorization: Bearer $(gcloud auth application-default print-access-token)" \ -H "Content-Type: application/json" \ -X POST \ -d '{"instances":[{"Pclass":2, "Sex":"male", "Age":33, "SibSp":1, "Parch":2, "Embarked":"Q"}]}' \ https://ml.googleapis.com/v1/projects/${YOUR_PROJECT_ID}/models/${YOUR_MODEL_NAME}:predict
アプリに組み込んで予測実行
ここまででも予測は実行できてますが、やはりアプリっぽいものから実行する方が楽しいと思うので対応していきます。LINEのBot(Messaging API)なら簡単に要件を満たせるのでこの方向で進めていきましょう。
ここから先はアプリ化に必要な流れや情報をかなり簡単に記載していきます。(雰囲気だけ感じていただければ、もしこの情報だけで手を動かせそうでしたらぜひ!)
下記にざっくり必要なものを記載します。
- LINE Bot(Messaging API)
- Cloud Run
- Datastore
- 何かしらのPythonで開発できる環境(私は簡単なのでGitHub Codespacesを使いました♪)
フォルダ構成は下記です。
. ├── ai-work-275303-b0d2af79eeab.json ├── app.py └── Dockerfile
各ファイルの中身は下記です。
- app.py
import json import os from flask import Flask, abort, request from google.cloud import datastore from googleapiclient import discovery from oauth2client.client import GoogleCredentials from linebot import ( LineBotApi, WebhookHandler ) from linebot.exceptions import ( InvalidSignatureError ) from linebot.models import ( MessageEvent, TextMessage, TextSendMessage, StickerSendMessage, PostbackEvent, PostbackAction, QuickReply, QuickReplyButton ) client = datastore.Client() app = Flask(__name__) line_bot_api = LineBotApi(os.getenv('LINE_CHANNEL_ACCESS_TOKEN', None)) handler = WebhookHandler(os.getenv('LINE_CHANNEL_SECRET', None)) sa_keyfile = os.getcwd() + '/ai-work-275303-63c1bb1f0331.json' PROJECT_ID = os.getenv('PROJECT_ID', None) MODEL_NAME = os.getenv('MODEL_NAME', None) URL = 'https://ml.googleapis.com/v1/projects/{}/models/{}:predict' @app.route('/callback', methods=['POST']) def callback(): # get X-Line-Signature header value signature = request.headers['X-Line-Signature'] # get request body as text body = request.get_data(as_text=True) app.logger.info('Request body: ' + body) # handle webhook body try: handler.handle(body, signature) except InvalidSignatureError: print('Invalid signature. Please check your channel access token/channel secret.') abort(400) return 'OK' @handler.add(MessageEvent, message=TextMessage) def handle_message(event): key = client.key('Titanic', event.source.user_id) entity = datastore.Entity(key=key) result = client.get(key) if event.message.text == '予測': entity.update({ 'question': '0', }) client.put(entity) line_bot_api.reply_message( event.reply_token, TextSendMessage( text='チケットのクラスはどれですか?', quick_reply=QuickReply( items=[ QuickReplyButton( action=PostbackAction(label='1st', data='1', display_text='1st') ), QuickReplyButton( action=PostbackAction(label='2nd', data='2', display_text='2nd') ), QuickReplyButton( action=PostbackAction(label='3rd', data='3', display_text='3rd') ) ]))) elif result.get('question') == '2': entity.update({ 'question': '3', 'Pclass': result.get('Pclass'), 'Sex': result.get('Sex'), 'Age': event.message.text, }) client.put(entity) line_bot_api.reply_message( event.reply_token, TextSendMessage(text='乗船している兄弟・配偶者の人数を教えてください。')) elif result.get('question') == '3': entity.update({ 'question': '4', 'Pclass': result.get('Pclass'), 'Sex': result.get('Sex'), 'Age': result.get('Age'), 'SibSp': event.message.text }) client.put(entity) line_bot_api.reply_message( event.reply_token, TextSendMessage(text='乗船している両親・子供の人数を教えてください。')) elif result.get('question') == '4': entity.update({ 'question': '5', 'Pclass': result.get('Pclass'), 'Sex': result.get('Sex'), 'Age': result.get('Age'), 'SibSp': result.get('SibSp'), 'Parch': event.message.text }) client.put(entity) line_bot_api.reply_message( event.reply_token, TextSendMessage( text='乗船した港はどれですか?', quick_reply=QuickReply( items=[ QuickReplyButton( action=PostbackAction(label='Cherbourg', data='C', display_text='Cherbourg') ), QuickReplyButton( action=PostbackAction(label='Queenstown', data='Q', display_text='Queenstown') ), QuickReplyButton( action=PostbackAction(label='Southampton', data='S', display_text='Southampton') ) ]))) else: line_bot_api.reply_message( event.reply_token, TextSendMessage( text='『予測』とメッセージを送ってみてください。')) @handler.add(PostbackEvent) def handle_postback(event): key = client.key('Titanic', event.source.user_id) entity = datastore.Entity(key=key) result = client.get(key) app.logger.info(result) if result.get('question') == '0': entity.update({ 'question': '1', 'Pclass': event.postback.data, }) client.put(entity) line_bot_api.reply_message( event.reply_token, TextSendMessage( text='性別はどちらですか?', quick_reply=QuickReply( items=[ QuickReplyButton( action=PostbackAction(label='男性', data='male', display_text='男性') ), QuickReplyButton( action=PostbackAction(label='女性', data='female', display_text='女性') ) ]))) elif result.get('question') == '1': entity.update({ 'question': '2', 'Pclass': result.get('Pclass'), 'Sex': event.postback.data, }) client.put(entity) line_bot_api.reply_message( event.reply_token, TextSendMessage(text='年齢はいくつですか?')) elif result.get('question') == '5': entity.update({ 'question': '5', 'Pclass': result.get('Pclass'), 'Sex': result.get('Sex'), 'Age': result.get('Age'), 'SibSp': result.get('SibSp'), 'Parch': result.get('Parch'), 'Embarked': event.postback.data, }) client.put(entity) inputs_for_prediction = [ {'Pclass':result.get('Pclass'), 'Sex':result.get('Sex'), 'Age':result.get('Age'), 'SibSp':result.get('SibSp'), 'Parch':result.get('Parch'), 'Embarked':event.postback.data} ] credentials = GoogleCredentials.from_stream(sa_keyfile) service = discovery.build('ml', 'v1', credentials=credentials) name = 'projects/{}/models/{}'.format(PROJECT_ID, MODEL_NAME) response = service.projects().predict( name=name, body={'instances': inputs_for_prediction} ).execute() predicted_survived = response['predictions'][0]['predicted_Survived'] if predicted_survived == '1': line_bot_api.reply_message( event.reply_token, [TextSendMessage( text='安心してください。あなたは無事に帰ってこれるでしょう。'), StickerSendMessage( package_id='11537', sticker_id='52002735')]) elif predicted_survived == '0': line_bot_api.reply_message( event.reply_token, [TextSendMessage( text='あなたには困難な運命が待ち受けている...かもしれません...'), StickerSendMessage( package_id='11537', sticker_id='52002755')]) else: line_bot_api.reply_message( event.reply_token, TextSendMessage(text='『予測』とメッセージを送ってみてください。')) if __name__ == '__main__': app.run()
- Dockerfile
# Use the official Python image. # https://hub.docker.com/_/python FROM python:3.7 # Copy local code to the container image. ENV APP_HOME /app WORKDIR $APP_HOME COPY . . # Install production dependencies. RUN pip install Flask gunicorn line-bot-sdk google-cloud-datastore oauth2client google-api-python-client # Run the web service on container startup. Here we use the gunicorn # webserver, with one worker process and 8 threads. # For environments with multiple CPU cores, increase the number of workers # to be equal to the cores available. CMD exec gunicorn --bind :$PORT --workers 1 --threads 8 app:app
GitHub Codespacesの拡張機能にはCloud runのデプロイなどを操作できるCloud Codeがあるのでそちらを使ってデプロイを実施します。
アプリへの組み込み完了しました。最近のテクノロジーを駆使するとアプリ開発はすごく楽になりますね!それでは本題のアプリから予測を実行していきましょう。
下記は私っぽい情報を入れたところです...ちょっとダメそうな予測結果が返ってきました...
ちなみに妻は...
おお!良かった!無事のようですε-(´∀`*)ホッ
まとめ
BQMLで作成したモデルをオンライン予測するまでの一通りの流れを見ていただきました。いちおう簡単にですがアプリっぽいものへの組み込みまでやっていきました。BQMLも機械学習の民主化を強く感じるものでしたが、昨今の開発の多くが非常にハードルが下がっていて頭にイメージしたものがクイックに実現できる素晴らしい世の中になっているなぁと感動もしておりました。ぜひ皆さまもBigQuery、機械学習を含む自分のやりたいことをこの冬休みに考えたり実行してみたりしてはいかがでしょうか?その際に本アドベントカレンダーが皆さまの助けになれれば幸いです。
参考
- https://cloud.google.com/bigquery-ml/docs/export-model-tutorial#train_and_deploy_a_boosted_tree_classifier_model
- https://cloud.google.com/ai-platform/prediction/docs/machine-types-online-prediction
- https://cloud.google.com/ai-platform/prediction/docs/getting-predictions-xgboost
- https://cloud.google.com/ai-platform/pricing
- https://cloud.google.com/ai-platform/prediction/pricing
- https://cloud.google.com/ai-platform/prediction/docs/online-predict
- https://cloud.google.com/blog/ja/products/data-analytics/new-ml-models-built-into-cloud-data-analytics